{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.pooling import MaxPooling2D, AveragePooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    return [round(x * 10**places) / 10**places for x in arr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Ordered mapping that collects the generated fixtures (keyed by name, e.g. 'pipeline_11') for JSON export.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 1011\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "# Pipeline under test: 3x3 'valid' convolution (4 filters, ReLU, with bias)\n",
    "# followed by 2x2 average pooling with stride 1.\n",
    "layers = [\n",
    "    Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True),\n",
    "    AveragePooling2D(pool_size=(2,2), strides=(1,1), padding='valid', data_format='channels_last')\n",
    "]\n",
    "\n",
    "# Chain every layer in order. A single loop (instead of special-casing the first\n",
    "# and last entries) builds the same graph and stays correct even if `layers`\n",
    "# holds only one layer, which would otherwise be applied twice.\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = input_layer\n",
    "for layer in layers:\n",
    "    x = layer(x)\n",
    "output_layer = x\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# input values are drawn uniformly from [-1, 1)\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# set weights to random (use seed for reproducibility)\n",
    "# NOTE: for i == 0 the seed equals the one used for data_in, so the first weight\n",
    "# tensor repeats the leading input values. This is kept on purpose so that the\n",
    "# exported fixture values do not change.\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# predict on a batch of one sample; result[0] is the output for that sample\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "\n",
    "# record flattened input, weights and expected output together with their shapes\n",
    "DATA['pipeline_11'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Export the collected fixtures as JSON for the Keras.js tests.\n",
    "filename = '../../test/data/pipeline/11.json'\n",
    "# exist_ok=True creates the target directory only when it is missing, without a\n",
    "# separate existence check beforehand.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_11\": {\"input\": {\"data\": [0.311362, -0.228519, 0.253024, -0.775634, 0.616056, -0.044226, 0.253076, 0.500853, 0.984775, 0.394826, -0.895291, -0.177992, -0.882996, 0.328019, -0.129043, 0.607012, -0.734911, -0.751302, 0.56571, -0.01302, -0.568758, 0.479133, 0.343441, -0.064802, -0.706584, -0.44282, -0.832317, -0.874655, -0.695238, -0.827426, -0.894361, 0.32456, -0.61541, 0.831896, -0.272104, -0.785451, 0.838743, 0.979695, 0.761349, -0.32666, 0.157434, -0.259107, 0.686692, -0.468014, -0.853276, -0.456564, 0.408532, 0.658449, 0.803013, -0.442139, -0.34814, -0.405807, -0.92347, 0.975551, 0.948464, -0.707809, -0.721025, -0.767626, 0.62826, 0.670958, -0.053223, 0.398117, -0.527273, -0.787475, 0.768245, -0.442372, 0.420796, -0.855725, -0.875469, -0.864543, 0.170765, -0.726726, -0.915107, 0.779212, 0.449691, 0.976425, 0.321766, 0.960179, -0.099143, -0.867347, -0.141255, -0.593658, 0.108986, -0.042904, 0.08184, -0.728972, 0.58547, 0.040569, 0.704295, -0.228271, -0.664164, 0.541857, -0.124211, 0.411514, -0.552667, 0.194645, 0.312569, -0.128953, -0.464284, 0.391605, 0.751542, -0.523649, -0.332429, 0.527063, 0.533137, -0.494064, 0.911437, -0.23134, -0.89969, 0.988745, -0.849639, -0.191247, 0.925705, 0.8785, 0.610116, -0.741279, -0.163346, -0.707725, -0.073795, 0.58394, -0.860325, -0.89594, -0.299924, -0.946011, -0.83693, -0.601976, 0.234718, 0.367061], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.311362, -0.228519, 0.253024, -0.775634, 0.616056, -0.044226, 0.253076, 0.500853, 0.984775, 0.394826, -0.895291, -0.177992, -0.882996, 0.328019, -0.129043, 0.607012, -0.734911, -0.751302, 0.56571, -0.01302, -0.568758, 0.479133, 0.343441, -0.064802, -0.706584, -0.44282, -0.832317, -0.874655, -0.695238, -0.827426, -0.894361, 0.32456, -0.61541, 0.831896, -0.272104, -0.785451, 0.838743, 0.979695, 0.761349, -0.32666, 0.157434, -0.259107, 0.686692, -0.468014, -0.853276, -0.456564, 0.408532, 0.658449, 0.803013, -0.442139, -0.34814, -0.405807, -0.92347, 0.975551, 
0.948464, -0.707809, -0.721025, -0.767626, 0.62826, 0.670958, -0.053223, 0.398117, -0.527273, -0.787475, 0.768245, -0.442372, 0.420796, -0.855725, -0.875469, -0.864543, 0.170765, -0.726726], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.946541, 0.585593, -0.49527, 0.594532], \"shape\": [4]}], \"expected\": {\"data\": [0.385866, 1.127311, 0.86104, 0.585367, 0.385866, 1.461206, 0.404148, 0.665295, 1.182823, 0.661164, 0.14916, 0.893525, 1.22361, 0.125944, 0.0, 1.191117, 0.040787, 0.32045, 0.0, 1.735492, 1.208889, 1.18725, 0.254987, 1.936405, 1.208889, 1.604491, 0.254987, 1.498114, 0.419612, 0.417241, 0.0, 1.135943, 0.419612, 0.825821, 0.290601, 0.144114, 0.0, 1.517887, 0.290601, 0.311044, 0.828196, 0.667275, 0.183549, 1.676492, 1.489199, 0.35261, 0.183549, 2.09406, 0.661003, 0.047335, 0.791086, 1.453214, 0.0, 1.311635, 1.081687, 0.225057, 0.0, 2.307613, 0.290601, 0.178278, 1.022356, 0.82228, 0.472916, 0.340965, 1.125167, 0.682943, 0.472916, 1.140136, 1.125167, 0.795228, 1.067056, 0.908882, 0.0, 0.598098, 1.067056, 1.063713, 0.286897, 1.500481, 0.0, 1.379561, 1.076673, 0.620474, 0.464362, 0.336402, 0.573155, 1.000334, 0.289367, 0.321843, 0.690637, 1.814722, 0.27597, 0.165168, 0.117482, 1.158726, 0.27597, 1.595348, 0.548575, 0.982322, 0.0, 2.328448], \"shape\": [5, 5, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# Echo the serialized fixtures so the exported JSON can be inspected in the notebook output.\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
